package com.alibaba.alink.operator.batch.dataproc;

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

import com.alibaba.alink.common.annotation.InputPorts;
import com.alibaba.alink.common.annotation.NameCn;
import com.alibaba.alink.common.annotation.OutputPorts;
import com.alibaba.alink.common.annotation.ParamSelectColumnSpec;
import com.alibaba.alink.common.annotation.PortDesc;
import com.alibaba.alink.common.annotation.PortSpec;
import com.alibaba.alink.common.annotation.PortType;
import com.alibaba.alink.common.annotation.ReservedColsWithFirstInputSpec;
import com.alibaba.alink.common.annotation.TypeCollections;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.params.dataproc.HugeMultiStringIndexerPredictParams;
import org.apache.commons.lang.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;

/**
 * Map ids back to their string tokens, i.e. the reverse direction of string indexing.
 *
 * <p>The first input is the model; each model row is read as (field 0: String token, field 1: Long id).
 * The second input is the data. Every selected column is read either as a single Long id (when the
 * column type is a basic type) or as a Long[] of ids (otherwise). Each output column is a STRING:
 * the decoded token for a single id, or the decoded tokens joined with "," in array order.
 *
 * <p>Ids without a matching model row decode to "notFound". Null cells and null array elements are
 * looked up with the sentinel id -1, so they decode to "notFound" too unless the model contains id -1.
 * An empty array decodes to the empty string.
 *
 * <p>NOTE(review): the model is presumably produced by {@link StringIndexerTrainBatchOp} or a
 * compatible operator emitting (token, id) rows - confirm against the training side.
 * The <code>handleInvalid</code> parameter is not consulted by this operator.
 */
@InputPorts(values = {
	@PortSpec(value = PortType.MODEL, desc = PortDesc.PREDICT_INPUT_MODEL),
	@PortSpec(value = PortType.DATA, desc = PortDesc.PREDICT_INPUT_DATA)
})
@OutputPorts(values = {@PortSpec(value = PortType.DATA, desc = PortDesc.OUTPUT_RESULT)})
@ParamSelectColumnSpec(name = "selectedCols",
	allowedTypeCollections = TypeCollections.LONG_TYPES)
@ReservedColsWithFirstInputSpec
@NameCn("并行ID化预测")
public final class HugeIndexerStringPredictBatchOp
	extends BatchOperator <HugeIndexerStringPredictBatchOp>
	implements HugeMultiStringIndexerPredictParams<HugeIndexerStringPredictBatchOp> {

	private static final long serialVersionUID = -794572755838107745L;

	/**
	 * Sentinel id used for null cells and null array elements, so tuples never carry a null field.
	 */
	private static final Long MISSING_ID = -1L;

	/**
	 * Token emitted when an id has no matching row in the model.
	 */
	private static final String NOT_FOUND = "notFound";

	public HugeIndexerStringPredictBatchOp() {
		this(new Params());
	}

	public HugeIndexerStringPredictBatchOp(Params params) {
		super(params);
	}

	/**
	 * Extract the (token, id) pairs from the model rows.
	 *
	 * @param model The model operator; field 0 of each row is cast to String, field 1 to Long.
	 * @return A DataSet of tuples of (token, id).
	 */
	private DataSet <Tuple2 <String, Long>> getModelData(BatchOperator <?> model) {
		DataSet <Row> modelRows = model.getDataSet();
		return modelRows
			.flatMap(new FlatMapFunction <Row, Tuple2 <String, Long>>() {
				private static final long serialVersionUID = 7697943140162154366L;

				@Override
				public void flatMap(Row row, Collector <Tuple2 <String, Long>> out) throws Exception {
					out.collect(Tuple2.of((String) row.getField(0), (Long) row.getField(1)));
				}
			})
			.name("get_model_data")
			.returns(new TupleTypeInfo <>(Types.STRING, Types.LONG));
	}

	/**
	 * Build the prediction pipeline: flatten the selected id columns, look every id up in the model,
	 * re-assemble one decoded string per column and merge it back into the original rows.
	 *
	 * @param inputs inputs[0] is the model, inputs[1] is the data to decode.
	 * @return this operator, with its output set to the merged result.
	 * @throws IllegalArgumentException if fewer than two inputs are given, or if the number of output
	 *                                  columns differs from the number of selected columns.
	 */
	@Override
	public HugeIndexerStringPredictBatchOp linkFrom(BatchOperator <?>... inputs) {
		if (inputs == null || inputs.length < 2) {
			throw new IllegalArgumentException(
				"HugeIndexerStringPredictBatchOp requires two inputs: the model and the data.");
		}
		Params params = getParams();
		BatchOperator <?> model = inputs[0];
		BatchOperator <?> data = inputs[1];

		String[] selectedColNames = params.get(HugeMultiStringIndexerPredictParams.SELECTED_COLS);
		String[] outputColNames = params.get(HugeMultiStringIndexerPredictParams.OUTPUT_COLS);
		if (outputColNames == null) {
			// Without explicit output columns the decoded values replace the selected columns.
			outputColNames = selectedColNames;
		}
		if (outputColNames.length != selectedColNames.length) {
			// Fail fast: the result row has one field per selected column.
			throw new IllegalArgumentException(
				"The number of output columns (" + outputColNames.length
					+ ") must match the number of selected columns (" + selectedColNames.length + ").");
		}
		String[] keepColNames = params.get(HugeMultiStringIndexerPredictParams.RESERVED_COLS);

		TypeInformation <?>[] outputColTypes = new TypeInformation <?>[outputColNames.length];
		Arrays.fill(outputColTypes, Types.STRING);
		TypeInformation <?>[] inputColTypes = TableUtil.findColTypesWithAssert(data.getSchema(), selectedColNames);

		// Only the scalar/array distinction is needed inside the flat map, so capture just that.
		final boolean[] scalarColumn = new boolean[inputColTypes.length];
		for (int i = 0; i < inputColTypes.length; i++) {
			scalarColumn[i] = inputColTypes[i].isBasicType();
		}

		final OutputColsHelper outputColsHelper = new OutputColsHelper(data.getSchema(), outputColNames,
			outputColTypes, keepColNames);

		final int[] selectedColIdx = TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), selectedColNames);
		final int numSelectedCols = selectedColIdx.length;

		DataSet <Tuple2 <Long, Row>> dataWithId = DataSetUtils.zipWithUniqueId(data.getDataSet());

		DataSet <Tuple2 <String, Long>> modelData = getModelData(model);

		// tuple: record id, selected column position, position inside the array (0 for scalars), token id
		DataSet <Tuple4 <Long, Integer, Integer, Long>> flattened = dataWithId
			.flatMap(new RichFlatMapFunction <Tuple2 <Long, Row>, Tuple4 <Long, Integer, Integer, Long>>() {
				private static final long serialVersionUID = -8382461068855755626L;

				@Override
				public void flatMap(Tuple2 <Long, Row> value, Collector <Tuple4 <Long, Integer, Integer, Long>> out)
					throws Exception {
					for (int i = 0; i < selectedColIdx.length; i++) {
						Object cell = value.f1.getField(selectedColIdx[i]);
						if (cell == null) {
							out.collect(Tuple4.of(value.f0, i, 0, MISSING_ID));
						} else if (scalarColumn[i]) {
							out.collect(Tuple4.of(value.f0, i, 0, (Long) cell));
						} else {
							Long[] ids = (Long[]) cell;
							for (int j = 0; j < ids.length; j++) {
								// A null element would put a null field into the tuple; use the same
								// sentinel as for a null cell instead.
								Long tokenId = ids[j] == null ? MISSING_ID : ids[j];
								out.collect(Tuple4.of(value.f0, i, j, tokenId));
							}
						}
					}
				}
			})
			.name("flatten_pred_data")
			.returns(new TupleTypeInfo <>(Types.LONG, Types.INT, Types.INT, Types.LONG));

		// tuple: record id, selected column position, position inside the array, decoded token
		DataSet <Tuple4 <Long, Integer, Integer, String>> indexed = flattened
			.leftOuterJoin(modelData)
			.where(3).equalTo(1)
			.with(
				new JoinFunction <Tuple4 <Long, Integer, Integer, Long>, Tuple2 <String, Long>, Tuple4 <Long, Integer, Integer, String>>() {
					private static final long serialVersionUID = 2270459281179536013L;

					@Override
					public Tuple4 <Long, Integer, Integer, String> join(Tuple4 <Long, Integer, Integer, Long> first,
															 Tuple2 <String, Long> second) throws Exception {
						// Left outer join: the model side is null when the id is unknown.
						String token = second == null ? NOT_FOUND : second.f0;
						return Tuple4.of(first.f0, first.f1, first.f2, token);
					}
				})
			.name("map_index_to_token")
			.returns(new TupleTypeInfo <>(Types.LONG, Types.INT, Types.INT, Types.STRING));

		// tuple: record id, row holding one decoded string per selected column
		DataSet <Tuple2 <Long, Row>> aggregateResult = indexed
			.groupBy(0)
			.reduceGroup(new GroupReduceFunction <Tuple4 <Long, Integer, Integer, String>, Tuple2 <Long, Row>>() {
				private static final long serialVersionUID = -1581264399340055162L;

				@Override
				public void reduce(Iterable <Tuple4 <Long, Integer, Integer, String>> values,
								   Collector <Tuple2 <Long, Row>> out) throws Exception {
					Long id = null;
					// Copy the fields out of each tuple before moving to the next element.
					ArrayList <Tuple3 <Integer, Integer, String>> entries = new ArrayList <>();
					for (Tuple4 <Long, Integer, Integer, String> v : values) {
						entries.add(Tuple3.of(v.f1, v.f2, v.f3));
						id = v.f0;
					}
					// Order by column position, then by position inside the array, so tokens are
					// joined in their original array order.
					entries.sort(Comparator
						.comparing((Tuple3 <Integer, Integer, String> e) -> e.f0)
						.thenComparing(e -> e.f1));

					// One bucket per selected column; a column without any entry (empty array)
					// keeps an empty bucket and therefore decodes to "".
					ArrayList <ArrayList <String>> tokensByColumn = new ArrayList <>(numSelectedCols);
					for (int i = 0; i < numSelectedCols; i++) {
						tokensByColumn.add(new ArrayList <String>());
					}
					for (Tuple3 <Integer, Integer, String> e : entries) {
						tokensByColumn.get(e.f0).add(e.f2);
					}

					Row decoded = new Row(numSelectedCols);
					for (int i = 0; i < numSelectedCols; i++) {
						decoded.setField(i, StringUtils.join(tokensByColumn.get(i), ","));
					}
					out.collect(Tuple2.of(id, decoded));
				}
			})
			.name("aggregate_result")
			.returns(new TupleTypeInfo <>(Types.LONG, new RowTypeInfo(outputColTypes)));

		// Left outer join so that records which produced no flattened tuple at all (every selected
		// column is an empty array) are still emitted instead of being dropped.
		DataSet <Row> output = dataWithId
			.leftOuterJoin(aggregateResult)
			.where(0).equalTo(0)
			.with(new JoinFunction <Tuple2 <Long, Row>, Tuple2 <Long, Row>, Row>() {
				private static final long serialVersionUID = 3724539437313089427L;

				@Override
				public Row join(Tuple2 <Long, Row> first, Tuple2 <Long, Row> second) throws Exception {
					Row decoded;
					if (second == null) {
						decoded = new Row(numSelectedCols);
						for (int i = 0; i < numSelectedCols; i++) {
							decoded.setField(i, "");
						}
					} else {
						decoded = second.f1;
					}
					return outputColsHelper.getResultRow(first.f1, decoded);
				}
			})
			.name("merge_result")
			.returns(new RowTypeInfo(outputColsHelper.getResultSchema().getFieldTypes()));

		this.setOutput(output, outputColsHelper.getResultSchema());
		return this;
	}
}
